from ai_collusion.gym_envs.envs.bertrand_competition import BertrandCompetitionDiscreteEnv, CostPerturbationWrapper, StackMDPWrapper, RestartExplorationRateWrapper, RLSupervisorQPricingWrapper, EvaluationAfterJohnsonConvergence
from stable_baselines3.common import logger
import os
from stable_baselines3.common.monitor import Monitor


def get_johnson_env(config_dict):
    """Create a two-agent Bertrand environment wrapped by the Q-pricing supervisor.

    Keys read from ``config_dict``: ``price_grid_length``, ``log_folder``,
    ``gamma``, ``dp_type``, ``marginal_cost``, ``max_steps`` and
    ``bbox_state_space_type``.

    A CSV-only logger is configured in ``config_dict["log_folder"]`` and passed
    to the ``RLSupervisorQPricingWrapper``. That wrapper is returned.
    """
    grid_size = config_dict['price_grid_length']

    # Supervisor hyperparameters. alpha/beta are the "Johnson" values; all
    # three are fixed here rather than read from config_dict.
    q_alpha, q_beta, q_delta = 0.15, 0.00001, 0.95

    csv_logger = logger.configure(
        folder=config_dict["log_folder"],
        format_strings=["csv"],
    )

    base_env = BertrandCompetitionDiscreteEnv(
        num_agents=2,
        gamma=config_dict['gamma'],
        dp_type=config_dict['dp_type'],
        m=grid_size,
        k=1,
        c_i=config_dict['marginal_cost'],
        max_steps=config_dict['max_steps'],
    )

    return RLSupervisorQPricingWrapper(
        base_env,
        delta=q_delta,
        alpha=q_alpha,
        beta=q_beta,
        bbox_state_space_type=config_dict['bbox_state_space_type'],
        logger=csv_logger,
    )


def get_standard_env(config_dict):
    """Create the full wrapped Bertrand training environment.

    Wrapper layers, innermost first:

    1. ``BertrandCompetitionDiscreteEnv``: two agents, a price grid of
       ``price_grid_length`` points bounded by ``grid_upper_bound``.
    2. ``CostPerturbationWrapper``: applied only when
       ``config_dict["experiment_type"]`` is ``"train_with_cost_perturbation"``.
    3. ``RLSupervisorQPricingWrapper``: given a CSV-only logger configured in
       ``config_dict["log_folder"]``.
    4. ``StackMDPWrapper``: configured from the ``tot_num_reward_steps``,
       ``tot_num_eq_steps``, ``frac_excluded_eq_steps``,
       ``reward_step_random_price_prob`` and ``critic_obs`` keys.
    5. ``RestartExplorationRateWrapper``: given ``q_restart_rate``.

    Returns the outermost (``RestartExplorationRateWrapper``) environment.
    """
    grid_size = config_dict['price_grid_length']

    # Supervisor hyperparameters. alpha/beta are the "Johnson" values; all
    # three are fixed here rather than read from config_dict.
    q_alpha, q_beta, q_delta = 0.15, 0.00001, 0.95

    csv_logger = logger.configure(
        folder=config_dict["log_folder"],
        format_strings=["csv"],
    )

    env = BertrandCompetitionDiscreteEnv(
        num_agents=2,
        gamma=config_dict['gamma'],
        dp_type=config_dict['dp_type'],
        m=grid_size,
        c_i=config_dict['marginal_cost'],
        k=1,
        max_steps=config_dict['max_steps'],
        grid_upper_bound=config_dict['grid_upper_bound'],
    )

    perturb_costs = config_dict["experiment_type"] == "train_with_cost_perturbation"
    if perturb_costs:
        env = CostPerturbationWrapper(env)

    env = RLSupervisorQPricingWrapper(
        env,
        delta=q_delta,
        alpha=q_alpha,
        beta=q_beta,
        bbox_state_space_type=config_dict['bbox_state_space_type'],
        logger=csv_logger,
    )

    stack_kwargs = {
        'tot_num_reward_steps': config_dict['tot_num_reward_steps'],
        'tot_num_eq_steps': config_dict['tot_num_eq_steps'],
        'frac_excluded_eq_steps': config_dict['frac_excluded_eq_steps'],
        'reward_step_random_price_prob': config_dict['reward_step_random_price_prob'],
        'critic_obs': config_dict['critic_obs'],
    }
    env = StackMDPWrapper(env, **stack_kwargs)

    return RestartExplorationRateWrapper(env, config_dict['q_restart_rate'])


def get_no_Stackelberg_env(config_dict):
    """Create the Bertrand environment with only the Q-pricing supervisor on top.

    Unlike ``get_standard_env``, no ``StackMDPWrapper`` (or other outer wrapper)
    is applied. The supervisor wrapper is built with ``log_freq=50000``.

    Keys read from ``config_dict``: ``price_grid_length``, ``log_folder``,
    ``gamma``, ``dp_type``, ``marginal_cost``, ``max_steps`` and
    ``bbox_state_space_type``.

    Returns the ``RLSupervisorQPricingWrapper`` instance.
    """
    grid_size = config_dict['price_grid_length']

    # Supervisor hyperparameters. alpha/beta are the "Johnson" values; all
    # three are fixed here rather than read from config_dict.
    q_alpha, q_beta, q_delta = 0.15, 0.00001, 0.95

    csv_logger = logger.configure(
        folder=config_dict["log_folder"],
        format_strings=["csv"],
    )

    base_env = BertrandCompetitionDiscreteEnv(
        num_agents=2,
        gamma=config_dict['gamma'],
        dp_type=config_dict['dp_type'],
        m=grid_size,
        k=1,
        c_i=config_dict['marginal_cost'],
        max_steps=config_dict['max_steps'],
    )

    return RLSupervisorQPricingWrapper(
        base_env,
        delta=q_delta,
        alpha=q_alpha,
        beta=q_beta,
        bbox_state_space_type=config_dict['bbox_state_space_type'],
        log_freq=50000,
        logger=csv_logger,
    )